# -*- coding: utf-8 -*-
"""
Created on Sat Dec  5 16:23:54 2020

@author: Team317
"""
import torch
import torch.nn as nn
# from resnet100 import KitModel
from age_net.resnet100 import KitModel

class agenet(KitModel):
    """Small fully connected head stacked on the ``KitModel`` backbone.

    The backbone's forward output is pushed through four linear layers
    (512 -> 128 -> 64 -> 8 -> 1). A rescaled ``tanh`` sits between the first
    and the second linear layer; the remaining layers are chained directly.

    NOTE(review): the class/package names suggest the scalar output is an age
    estimate -- confirm against the training code.
    """

    def __init__(self, weight_file):
        """Construct the backbone, then attach the head layers.

        Args:
            weight_file: Forwarded unchanged to ``KitModel.__init__``.
        """
        super(agenet, self).__init__(weight_file)

        # Head: 512 -> 128 -> (tanh) -> 64 -> 8 -> 1.
        self.fc2 = nn.Linear(in_features=512, out_features=128)
        self.tanh = nn.Tanh()
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.fc4 = nn.Linear(in_features=64, out_features=8)

        # Final projection down to a single value per sample.
        self.fc5 = nn.Linear(in_features=8, out_features=1)

    def forward(self, inputs):
        """Run the backbone followed by the head.

        Args:
            inputs: Passed unchanged to ``KitModel.forward``.

        Returns:
            The output of ``fc5``; its last dimension has size 1.
        """
        # ``fc2`` expects 512 input features, so the backbone output must
        # carry 512 values in its last dimension.
        features = KitModel.forward(self, inputs)
        hidden = self.fc2(features)

        # tanh(x / 1000) + 1 maps the activations into the range 0..2.
        # An earlier variant fed tanh the unscaled activations; the division
        # by 1000 is presumably there to keep tanh away from saturation
        # -- TODO confirm.
        squashed = self.tanh(hidden / 1000) + 1

        # The factor 5 scales the *output* of fc3, not its input.
        hidden = self.fc3(squashed) * 5

        return self.fc5(self.fc4(hidden))
        